Library Imports
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
Template
spark = (
SparkSession.builder
.master("local")
.appName("Section 3.2 - Range Join Conditions (WIP)")
.config("spark.some.config.option", "some-value")
.getOrCreate()
)
sc = spark.sparkContext
geo_loc_table = spark.createDataFrame([
(1, 10, "foo"),
(11, 36, "bar"),
(37, 59, "baz"),
], ["ipstart", "ipend", "loc"])
geo_loc_table.toPandas()
|   | ipstart | ipend | loc |
|---|---------|-------|-----|
| 0 | 1       | 10    | foo |
| 1 | 11      | 36    | bar |
| 2 | 37      | 59    | baz |
records_table = spark.createDataFrame([
(1, 11),
(2, 38),
(3, 50),
],["id", "inet"])
records_table.toPandas()
|   | id | inet |
|---|----|------|
| 0 | 1  | 11   |
| 1 | 2  | 38   |
| 2 | 3  | 50   |
Range Join Conditions
A naive approach (just specifying the range condition directly in the join) would result in a full Cartesian product and a filter that enforces the condition (tested using Spark 2.0). This has a horrible effect on performance, especially if the DataFrames are more than a few hundred thousand records.
source: http://zachmoshe.com/2016/09/26/efficient-range-joins-with-spark.html
The source of the problem is pretty simple. When you execute a join and the join condition is not equality-based, the only thing Spark can do right now is expand it to a Cartesian product followed by a filter, which is pretty much what happens inside BroadcastNestedLoopJoin.
source: https://stackoverflow.com/questions/37953830/spark-sql-performance-join-on-value-between-min-and-max?answertab=active#tab-top
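To make the quoted behaviour concrete, the sketch below spells out the "Cartesian product followed by a filter" plan by hand (inner-join flavour only, for illustration; naive_df is just a throwaway name, and crossJoin requires Spark 2.1+):
# Conceptually what Spark falls back to for a non-equi join condition:
# pair every record with every geo range, then filter the pairs.
naive_df = (
    records_table
    .crossJoin(geo_loc_table)
    .filter(
        (F.col("inet") >= F.col("ipstart"))
        & (F.col("inet") <= F.col("ipend"))
    )
)
naive_df.explain()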
Option #1
join_condition = [
records_table['inet'] >= geo_loc_table['ipstart'],
records_table['inet'] <= geo_loc_table['ipend'],
]
df = records_table.join(geo_loc_table, join_condition, "left")
df.toPandas()
|   | id | inet | ipstart | ipend | loc |
|---|----|------|---------|-------|-----|
| 0 | 1  | 11   | 11      | 36    | bar |
| 1 | 2  | 38   | 37      | 59    | baz |
| 2 | 3  | 50   | 37      | 59    | baz |
df.explain()
== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, LeftOuter, ((inet#252L >= ipstart#245L) && (inet#252L <= ipend#246L))
:- Scan ExistingRDD[id#251L,inet#252L]
+- BroadcastExchange IdentityBroadcastMode
+- Scan ExistingRDD[ipstart#245L,ipend#246L,loc#247]
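The BroadcastNestedLoopJoin node is exactly the behaviour quoted above: the small geo table is broadcast to every executor and the range predicate is evaluated for every pair of rows, so the work grows with the product of the two row counts. A rough sanity check on this toy data:
# One predicate evaluation per (record, range) pair.
pairs = records_table.count() * geo_loc_table.count()
print(pairs)  # 3 records x 3 ranges = 9 pairs here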
Option #2
from bisect import bisect_right
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

# Collect the sorted range starts to the driver and broadcast them.
# Materialize a list (not a map() iterator) so bisect can index into it.
geo_start_bd = spark.sparkContext.broadcast([
    row.ipstart
    for row in geo_loc_table.select("ipstart").orderBy("ipstart").collect()
])

def find_le(x):
    """Find the rightmost ipstart less than or equal to x."""
    i = bisect_right(geo_start_bd.value, x)
    if i:
        return geo_start_bd.value[i - 1]
    return None
records_table_with_ipstart = records_table.withColumn(
"ipstart", udf(find_le, LongType())("inet")
)
df = records_table_with_ipstart.join(geo_loc_table, ["ipstart"], "left")
df.toPandas()
|   | ipstart | id | inet | ipend | loc |
|---|---------|----|------|-------|-----|
| 0 | 37      | 2  | 38   | 59    | baz |
| 1 | 37      | 3  | 50   | 59    | baz |
| 2 | 11      | 1  | 11   | 36    | bar |
df.explain()
== Physical Plan ==
*(4) Project [ipstart#272L, id#251L, inet#252L, ipend#246L, loc#247]
+- SortMergeJoin [ipstart#272L], [ipstart#245L], LeftOuter
:- *(2) Sort [ipstart#272L ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(ipstart#272L, 200)
: +- *(1) Project [id#251L, inet#252L, pythonUDF0#281L AS ipstart#272L]
: +- BatchEvalPython [find_le(inet#252L)], [id#251L, inet#252L, pythonUDF0#281L]
: +- Scan ExistingRDD[id#251L,inet#252L]
+- *(3) Sort [ipstart#245L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(ipstart#245L, 200)
+- Scan ExistingRDD[ipstart#245L,ipend#246L,loc#247]
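As a quick sanity check on the broadcast lookup (run on the driver, against the sample data above), and one hedged extra guard: find_le only looks at ipstart, so if the ranges ever contained gaps you would also want to re-check ipend after the join. The sample ranges are contiguous, so the result above is already correct without it.
# find_le maps each inet to the greatest ipstart <= inet;
# the broadcast list for the sample data is [1, 11, 37].
for inet in [11, 38, 50]:
    print(inet, "->", find_le(inet))  # 11 -> 11, 38 -> 37, 50 -> 37

# Optional guard for gappy ranges: keep unmatched rows (null ipend)
# and rows whose inet really falls inside the matched range.
df_checked = df.filter(
    F.col("ipend").isNull() | (F.col("inet") <= F.col("ipend"))
)
df_checked.toPandas()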